﻿// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Text;

namespace System.Security.Cryptography.Cose
{
    /// <summary>
    /// Pairs an asymmetric key (RSA, ECDsa, or ML-DSA) with the signing parameters needed to
    /// produce and check COSE signatures, and records the COSE algorithm identifier those
    /// parameters correspond to.
    /// </summary>
    public sealed class CoseKey
    {
        // COSE algorithm identifier derived from the key type and its signing parameters.
        internal CoseAlgorithm Algorithm { get; }

        // Discriminator telling which of the private key fields below is populated.
        internal KeyType KeyType { get; }

        // Hash algorithm used for RSA/ECDsa keys; never assigned (stays null) for ML-DSA keys.
        internal HashAlgorithmName? HashAlgorithm { get; }

        // Padding mode used for RSA keys; never assigned (stays null) for other key types.
        internal RSASignaturePadding? RSASignaturePadding { get; }

        // only used for backward compatibility
        // Assigned for RSA and ECDsa keys; stays null for ML-DSA keys.
        internal AsymmetricAlgorithm? AsymmetricAlgorithm { get; }

        private RSA? _rsaKey;
        private ECDsa? _ecdsaKey;

        [Experimental(Experimentals.PostQuantumCryptographyDiagId)]
        private MLDsa? _mldsaKey;

        /// <summary>
        /// Creates a COSE key backed by an RSA key.
        /// </summary>
        /// <param name="key">The RSA key used to sign or verify.</param>
        /// <param name="signaturePadding">The RSA signature padding (PSS or PKCS#1 v1.5).</param>
        /// <param name="hashAlgorithm">The hash algorithm (SHA256, SHA384, or SHA512).</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="key"/>, <paramref name="signaturePadding"/>, or the name of
        /// <paramref name="hashAlgorithm"/> is <see langword="null"/>.
        /// </exception>
        /// <exception cref="ArgumentException">The hash algorithm is not supported.</exception>
        public CoseKey(RSA key, RSASignaturePadding signaturePadding, HashAlgorithmName hashAlgorithm)
        {
            ArgumentNullException.ThrowIfNull(key);
            ArgumentNullException.ThrowIfNull(signaturePadding);
            ArgumentNullException.ThrowIfNull(hashAlgorithm.Name);

            KeyType = KeyType.RSA;
            Algorithm = GetRSAAlgorithm(signaturePadding, hashAlgorithm);

            HashAlgorithm = hashAlgorithm;
            AsymmetricAlgorithm = key;
            RSASignaturePadding = signaturePadding;
            _rsaKey = key;
        }

        /// <summary>
        /// Creates a COSE key backed by an ECDsa key.
        /// </summary>
        /// <param name="key">The ECDsa key used to sign or verify.</param>
        /// <param name="hashAlgorithm">The hash algorithm (SHA256, SHA384, or SHA512).</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="key"/> or the name of <paramref name="hashAlgorithm"/> is <see langword="null"/>.
        /// </exception>
        /// <exception cref="ArgumentException">The hash algorithm is not supported.</exception>
        public CoseKey(ECDsa key, HashAlgorithmName hashAlgorithm)
        {
            ArgumentNullException.ThrowIfNull(key);
            ArgumentNullException.ThrowIfNull(hashAlgorithm.Name);

            KeyType = KeyType.ECDsa;
            Algorithm = GetECDsaAlgorithm(hashAlgorithm);

            HashAlgorithm = hashAlgorithm;
            AsymmetricAlgorithm = key;
            _ecdsaKey = key;
        }

        /// <summary>
        /// Creates a COSE key backed by an ML-DSA key. The COSE algorithm is derived from the
        /// key's parameter set; no separate hash algorithm is used.
        /// </summary>
        /// <param name="key">The ML-DSA key used to sign or verify.</param>
        /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentException">The key's parameter set is not ML-DSA-44, -65, or -87.</exception>
        [Experimental(Experimentals.PostQuantumCryptographyDiagId)]
        public CoseKey(MLDsa key)
        {
            ArgumentNullException.ThrowIfNull(key);

            KeyType = KeyType.MLDsa;
            Algorithm = GetMLDsaAlgorithm(key.Algorithm);

            _mldsaKey = key;
        }

        /// <summary>
        /// Builds a <see cref="CoseKey"/> from a key object and an algorithm identifier that
        /// came from untrusted input (e.g. a message header).
        /// </summary>
        /// <remarks>
        /// The key type and the algorithm must agree. For ML-DSA, the key's parameter set must also
        /// match the algorithm; otherwise a <see cref="CryptographicException"/> is thrown.
        /// </remarks>
        /// <exception cref="CryptographicException">The algorithm is unknown or does not match the key.</exception>
        /// <exception cref="ArgumentException">The key is not an RSA, ECDsa, or ML-DSA instance.</exception>
        internal static CoseKey FromUntrustedAlgorithmAndKey(CoseAlgorithm untrustedAlgorithm, IDisposable key)
        {
            if (key is ECDsa ecdsaKey)
            {
                return untrustedAlgorithm switch
                {
                    CoseAlgorithm.ES256 => new CoseKey(ecdsaKey, HashAlgorithmName.SHA256),
                    CoseAlgorithm.ES384 => new CoseKey(ecdsaKey, HashAlgorithmName.SHA384),
                    CoseAlgorithm.ES512 => new CoseKey(ecdsaKey, HashAlgorithmName.SHA512),
                    _ => throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, untrustedAlgorithm))
                };
            }
            else if (key is RSA rsaKey)
            {
                return untrustedAlgorithm switch
                {
                    CoseAlgorithm.RS256 => new CoseKey(rsaKey, RSASignaturePadding.Pkcs1, HashAlgorithmName.SHA256),
                    CoseAlgorithm.RS384 => new CoseKey(rsaKey, RSASignaturePadding.Pkcs1, HashAlgorithmName.SHA384),
                    CoseAlgorithm.RS512 => new CoseKey(rsaKey, RSASignaturePadding.Pkcs1, HashAlgorithmName.SHA512),
                    CoseAlgorithm.PS256 => new CoseKey(rsaKey, RSASignaturePadding.Pss, HashAlgorithmName.SHA256),
                    CoseAlgorithm.PS384 => new CoseKey(rsaKey, RSASignaturePadding.Pss, HashAlgorithmName.SHA384),
                    CoseAlgorithm.PS512 => new CoseKey(rsaKey, RSASignaturePadding.Pss, HashAlgorithmName.SHA512),
                    _ => throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, untrustedAlgorithm))
                };
            }
#pragma warning disable SYSLIB5006
            else if (key is MLDsa mldsaKey)
            {
                return untrustedAlgorithm switch
                {
                    CoseAlgorithm.MLDsa44 => FromKeyWithExpectedAlgorithm(MLDsaAlgorithm.MLDsa44, mldsaKey),
                    CoseAlgorithm.MLDsa65 => FromKeyWithExpectedAlgorithm(MLDsaAlgorithm.MLDsa65, mldsaKey),
                    CoseAlgorithm.MLDsa87 => FromKeyWithExpectedAlgorithm(MLDsaAlgorithm.MLDsa87, mldsaKey),
                    _ => throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, untrustedAlgorithm))
                };

                // Reject the header algorithm if it names a different ML-DSA parameter set than the key has.
                CoseKey FromKeyWithExpectedAlgorithm(MLDsaAlgorithm expected, MLDsa key)
                    => key.Algorithm.Name == expected.Name ? new CoseKey(key) : throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, untrustedAlgorithm));
            }
#pragma warning restore SYSLIB5006
            else
            {
                throw new ArgumentException(SR.Format(SR.Sign1UnsupportedKey, key.GetType().Name), nameof(key));
            }
        }

        /// <summary>
        /// Converts a raw integer algorithm value into a supported <see cref="CoseAlgorithm"/>.
        /// </summary>
        /// <remarks>
        /// Non-negative values are rejected, as are values below <see cref="int.MinValue"/>.
        /// In-range values that are not a supported algorithm are also rejected.
        /// </remarks>
        /// <exception cref="CryptographicException">The value is not a supported algorithm.</exception>
        internal static CoseAlgorithm CoseAlgorithmFromInt64(long alg)
        {
            // Every supported identifier is negative and fits in an int; this also keeps the
            // cast below from silently truncating.
            if (alg >= 0 || alg < int.MinValue)
            {
                throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, alg));
            }

            CoseAlgorithm coseAlgorithm = (CoseAlgorithm)alg;
            ThrowIfCoseAlgorithmNotSupported(coseAlgorithm);

            return coseAlgorithm;
        }

        // Throws unless alg is one of the ES*, PS*, RS*, or ML-DSA algorithms this class can construct.
        private static void ThrowIfCoseAlgorithmNotSupported(CoseAlgorithm alg)
        {
#pragma warning disable SYSLIB5006
            if (alg is not CoseAlgorithm.ES256 and
                not CoseAlgorithm.ES384 and
                not CoseAlgorithm.ES512 and
                not CoseAlgorithm.PS256 and
                not CoseAlgorithm.PS384 and
                not CoseAlgorithm.PS512 and
                not CoseAlgorithm.RS256 and
                not CoseAlgorithm.RS384 and
                not CoseAlgorithm.RS512 and
                not CoseAlgorithm.MLDsa44 and
                not CoseAlgorithm.MLDsa65 and
                not CoseAlgorithm.MLDsa87)
            {
                throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, alg));
            }
#pragma warning restore SYSLIB5006
        }

        /// <summary>
        /// Maps an algorithm name (e.g. "ES256", "ML-DSA-65") to its <see cref="CoseAlgorithm"/> value.
        /// </summary>
        /// <exception cref="CryptographicException">The name is not a supported algorithm.</exception>
        internal static CoseAlgorithm CoseAlgorithmFromString(string algString)
        {
            // https://www.iana.org/assignments/cose/cose.xhtml#algorithms
            return algString switch
            {
                nameof(CoseAlgorithm.ES256) => CoseAlgorithm.ES256,
                nameof(CoseAlgorithm.ES384) => CoseAlgorithm.ES384,
                nameof(CoseAlgorithm.ES512) => CoseAlgorithm.ES512,
                nameof(CoseAlgorithm.PS256) => CoseAlgorithm.PS256,
                nameof(CoseAlgorithm.PS384) => CoseAlgorithm.PS384,
                nameof(CoseAlgorithm.PS512) => CoseAlgorithm.PS512,
                nameof(CoseAlgorithm.RS256) => CoseAlgorithm.RS256,
                nameof(CoseAlgorithm.RS384) => CoseAlgorithm.RS384,
                nameof(CoseAlgorithm.RS512) => CoseAlgorithm.RS512,
#pragma warning disable SYSLIB5006
                // IANA names for ML-DSA use hyphens, so nameof() on the enum members cannot be used here.
                "ML-DSA-44" => CoseAlgorithm.MLDsa44,
                "ML-DSA-65" => CoseAlgorithm.MLDsa65,
                "ML-DSA-87" => CoseAlgorithm.MLDsa87,
#pragma warning restore SYSLIB5006
                _ => throw new CryptographicException(SR.Format(SR.Sign1UnknownCoseAlgorithm, algString))
            };
        }

        /// <summary>
        /// Returns the number of bytes <see cref="Sign"/> produces for this key.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item><description>ECDsa: two field-width integers (r and s), each <c>(KeySize + 7) / 8</c> bytes.</description></item>
        /// <item><description>RSA: the modulus length in bytes.</description></item>
        /// <item><description>ML-DSA: the fixed signature size of the key's parameter set.</description></item>
        /// </list>
        /// </remarks>
        internal int ComputeSignatureSize()
        {
            switch (KeyType)
            {
                case KeyType.ECDsa:
                    return 2 * ((_ecdsaKey!.KeySize + 7) / 8);
                case KeyType.RSA:
                    return (_rsaKey!.KeySize + 7) / 8;
#pragma warning disable SYSLIB5006
                case KeyType.MLDsa:
                    return _mldsaKey!.Algorithm.SignatureSizeInBytes;
#pragma warning restore SYSLIB5006
                default:
                    Debug.Fail($"Unknown key type: {KeyType}");
                    throw new CryptographicException(SR.Format(SR.Sign1UnsupportedKey, KeyType.ToString()));
            }
        }

        /// <summary>
        /// Signs <paramref name="toBeSigned"/> into <paramref name="destination"/>.
        /// </summary>
        /// <param name="toBeSigned">
        /// For RSA/ECDsa, the precomputed hash (it is passed to <c>TrySignHash</c>).
        /// For ML-DSA, the data itself (it is passed to <c>SignData</c>).
        /// </param>
        /// <param name="destination">
        /// Output buffer. It is presumably sized with <see cref="ComputeSignatureSize"/> by the caller.
        /// </param>
        /// <returns>The number of bytes written.</returns>
        internal int Sign(ReadOnlySpan<byte> toBeSigned, Span<byte> destination)
        {
            switch (KeyType)
            {
                case KeyType.ECDsa:
                    Debug.Assert(_ecdsaKey != null);
                    return SignHashWithECDsa(_ecdsaKey, toBeSigned, destination);
                case KeyType.RSA:
                    Debug.Assert(_rsaKey != null);
                    Debug.Assert(HashAlgorithm != null);
                    Debug.Assert(RSASignaturePadding != null);
                    return SignHashWithRSA(_rsaKey, toBeSigned, HashAlgorithm.Value, RSASignaturePadding, destination);
#pragma warning disable SYSLIB5006
                case KeyType.MLDsa:
                    Debug.Assert(_mldsaKey != null);
                    // ML-DSA writes an exact-length signature, so hand it a slice of precisely that size.
                    Span<byte> mldsaSignature = destination.Slice(0, _mldsaKey.Algorithm.SignatureSizeInBytes);
                    _mldsaKey.SignData(toBeSigned, mldsaSignature);
                    return mldsaSignature.Length;
#pragma warning restore SYSLIB5006
                default:
                    Debug.Fail($"Unknown key type: {KeyType}");
                    throw new CryptographicException(SR.Format(SR.Sign1UnsupportedKey, KeyType.ToString()));
            }
        }

        // Signs a precomputed hash with ECDsa. A TrySignHash failure is unexpected because the
        // destination is assumed to be pre-sized, so it is surfaced as a CryptographicException.
        private static int SignHashWithECDsa(ECDsa key, ReadOnlySpan<byte> hash, Span<byte> destination)
        {
            if (!key.TrySignHash(hash, destination, out int bytesWritten))
            {
                Debug.Fail("TrySignHash failed with a pre-calculated destination");
                throw new CryptographicException();
            }

            return bytesWritten;
        }

        // Signs a precomputed hash with RSA using the given hash algorithm and padding.
        // A TrySignHash failure is unexpected because the destination is assumed to be pre-sized.
        private static int SignHashWithRSA(RSA key, ReadOnlySpan<byte> hash, HashAlgorithmName hashAlgorithm, RSASignaturePadding padding, Span<byte> destination)
        {
            if (!key.TrySignHash(hash, destination, hashAlgorithm, padding, out int bytesWritten))
            {
                Debug.Fail("TrySignHash failed with a pre-calculated destination");
                throw new CryptographicException();
            }

            return bytesWritten;
        }

        /// <summary>
        /// Checks <paramref name="signature"/> against <paramref name="toBeSigned"/>.
        /// </summary>
        /// <param name="toBeSigned">
        /// For RSA/ECDsa, the precomputed hash (it is passed to <c>VerifyHash</c>).
        /// For ML-DSA, the data itself (it is passed to <c>VerifyData</c>).
        /// </param>
        /// <param name="signature">The signature bytes to check.</param>
        /// <returns><see langword="true"/> if the signature is valid; otherwise <see langword="false"/>.</returns>
        internal bool Verify(ReadOnlySpan<byte> toBeSigned, ReadOnlySpan<byte> signature)
        {
            switch (KeyType)
            {
                case KeyType.ECDsa:
                {
                    Debug.Assert(_ecdsaKey != null);
                    return _ecdsaKey.VerifyHash(toBeSigned, signature);
                }
                case KeyType.RSA:
                {
                    Debug.Assert(_rsaKey != null);
                    Debug.Assert(RSASignaturePadding != null);
                    Debug.Assert(HashAlgorithm != null);
                    return _rsaKey.VerifyHash(toBeSigned, signature, HashAlgorithm.Value, RSASignaturePadding);
                }
#pragma warning disable SYSLIB5006
                case KeyType.MLDsa:
                {
                    Debug.Assert(_mldsaKey != null);
                    return _mldsaKey.VerifyData(toBeSigned, signature);
                }
#pragma warning restore SYSLIB5006
                default:
                {
                    Debug.Fail($"Unknown key type: {KeyType}");
                    // Match Sign/ComputeSignatureSize: format the message and don't misuse the key type as a parameter name.
                    throw new CryptographicException(SR.Format(SR.Sign1UnsupportedKey, KeyType.ToString()));
                }
            }
        }

        // Maps an RSA padding + hash pair to PS256/384/512 (PSS) or RS256/384/512 (PKCS#1 v1.5).
        // Throws ArgumentException for hash algorithms other than SHA256/384/512.
        private static CoseAlgorithm GetRSAAlgorithm(RSASignaturePadding signaturePadding, HashAlgorithmName hashAlgorithm)
        {
            Debug.Assert(signaturePadding != null);

            if (signaturePadding == RSASignaturePadding.Pss)
            {
                return hashAlgorithm.Name switch
                {
                    nameof(HashAlgorithmName.SHA256) => CoseAlgorithm.PS256,
                    nameof(HashAlgorithmName.SHA384) => CoseAlgorithm.PS384,
                    nameof(HashAlgorithmName.SHA512) => CoseAlgorithm.PS512,
                    _ => throw new ArgumentException(SR.Format(SR.Sign1SignUnsupportedHashAlgorithm, hashAlgorithm.Name), nameof(hashAlgorithm))
                };
            }

            Debug.Assert(signaturePadding == RSASignaturePadding.Pkcs1);

            return hashAlgorithm.Name switch
            {
                nameof(HashAlgorithmName.SHA256) => CoseAlgorithm.RS256,
                nameof(HashAlgorithmName.SHA384) => CoseAlgorithm.RS384,
                nameof(HashAlgorithmName.SHA512) => CoseAlgorithm.RS512,
                _ => throw new ArgumentException(SR.Format(SR.Sign1SignUnsupportedHashAlgorithm, hashAlgorithm.Name), nameof(hashAlgorithm))
            };
        }

        // Maps an ECDsa hash algorithm to ES256/384/512; throws ArgumentException for any other hash.
        private static CoseAlgorithm GetECDsaAlgorithm(HashAlgorithmName hashAlgorithm)
        {
            return hashAlgorithm.Name switch
            {
                nameof(HashAlgorithmName.SHA256) => CoseAlgorithm.ES256,
                nameof(HashAlgorithmName.SHA384) => CoseAlgorithm.ES384,
                nameof(HashAlgorithmName.SHA512) => CoseAlgorithm.ES512,
                _ => throw new ArgumentException(SR.Format(SR.Sign1SignUnsupportedHashAlgorithm, hashAlgorithm.Name), nameof(hashAlgorithm))
            };
        }

        // Maps an ML-DSA parameter set to its COSE identifier. The exception names "key" because
        // the constructor's key argument is what carries the unsupported parameter set.
        [Experimental(Experimentals.PostQuantumCryptographyDiagId)]
        private static CoseAlgorithm GetMLDsaAlgorithm(MLDsaAlgorithm algorithm)
        {
            if (algorithm.Name == MLDsaAlgorithm.MLDsa44.Name)
            {
                return CoseAlgorithm.MLDsa44;
            }
            else if (algorithm.Name == MLDsaAlgorithm.MLDsa65.Name)
            {
                return CoseAlgorithm.MLDsa65;
            }
            else if (algorithm.Name == MLDsaAlgorithm.MLDsa87.Name)
            {
                return CoseAlgorithm.MLDsa87;
            }
            else
            {
                throw new ArgumentException(SR.Format(SR.Sign1UnknownCoseAlgorithm, algorithm.Name), "key");
            }
        }

        // ML-DSA signs the raw to-be-signed bytes ("pure" mode). RSA and ECDsa sign a hash computed
        // with HashAlgorithm, which is always set for those key types.
        internal ToBeSignedBuilder CreateToBeSignedBuilder()
        {
            switch (KeyType)
            {
                case KeyType.MLDsa:
                    return new PureDataToBeSignedBuilder();
                default:
                    Debug.Assert(HashAlgorithm != null);
                    return new HashToBeSignedBuilder(HashAlgorithm.Value);
            }
        }
    }
}
